{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "import pandas as pd\n",
    "\n",
    "# SageMaker session, its default bucket and the execution role used by the cells below.\n",
    "sess = sagemaker.Session()\n",
    "bucket = sess.default_bucket()\n",
    "role = sagemaker.get_execution_role()\n",
    "\n",
    "# A single boto3 session supplies both the region name and the SageMaker client.\n",
    "boto_session = boto3.Session()\n",
    "region = boto_session.region_name\n",
    "sm = boto_session.client(service_name='sagemaker', region_name=region)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Specify the S3 Location of the Features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r scikit_processing_job_s3_output_prefix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Previous Scikit Processing Job Name: sagemaker-scikit-learn-2020-03-30-03-34-18-188\n"
     ]
    }
   ],
   "source": [
    "print('Previous Scikit Processing Job Name: {}'.format(scikit_processing_job_s3_output_prefix))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Key prefixes of the three BERT feature splits, all under '<stored prefix>/output'.\n",
    "output_root = '{}/output'.format(scikit_processing_job_s3_output_prefix)\n",
    "prefix_train = output_root + '/bert-train'\n",
    "prefix_validation = output_root + '/bert-validation'\n",
    "prefix_test = output_root + '/bert-test'\n",
    "\n",
    "# The same prefixes expressed as paths relative to the working directory.\n",
    "path_train = './' + prefix_train\n",
    "path_validation = './' + prefix_validation\n",
    "path_test = './' + prefix_test\n",
    "\n",
    "# ...and as S3 URIs inside the session's default bucket.\n",
    "s3_root = 's3://{}'.format(bucket)\n",
    "train_s3_uri = s3_root + '/' + prefix_train\n",
    "validation_s3_uri = s3_root + '/' + prefix_validation\n",
    "test_s3_uri = s3_root + '/' + prefix_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-1-835319576252/sagemaker-scikit-learn-2020-03-30-03-34-18-188/output/bert-train', 'S3DataDistributionType': 'FullyReplicated'}}}\n",
      "{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-1-835319576252/sagemaker-scikit-learn-2020-03-30-03-34-18-188/output/bert-validation', 'S3DataDistributionType': 'FullyReplicated'}}}\n",
      "{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-1-835319576252/sagemaker-scikit-learn-2020-03-30-03-34-18-188/output/bert-test', 'S3DataDistributionType': 'FullyReplicated'}}}\n"
     ]
    }
   ],
   "source": [
    "# One s3_input per split; content_type (e.g. 'text/csv') is left unset on all three.\n",
    "s3_input_train_data = sagemaker.s3_input(s3_data=train_s3_uri)\n",
    "s3_input_validation_data = sagemaker.s3_input(s3_data=validation_s3_uri)\n",
    "s3_input_test_data = sagemaker.s3_input(s3_data=test_s3_uri)\n",
    "\n",
    "# Show the channel configuration generated for each input.\n",
    "for channel_input in (s3_input_train_data, s3_input_validation_data, s3_input_test_data):\n",
    "    print(channel_input.config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "import sys\r\n",
      "import subprocess\r\n",
      "import argparse\r\n",
      "import json\r\n",
      "#subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'tensorflow-gpu==2.2.0-rc2'])\r\n",
      "subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'bert-for-tf2'])\r\n",
      "subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'sentencepiece'])\r\n",
      "\r\n",
      "import tensorflow as tf\r\n",
      "print(tf.__version__)\r\n",
      "\r\n",
      "import boto3\r\n",
      "import pandas as pd\r\n",
      "\r\n",
      "import os\r\n",
      "import math\r\n",
      "import datetime\r\n",
      "\r\n",
      "from tqdm import tqdm\r\n",
      "\r\n",
      "import pandas as pd\r\n",
      "import numpy as np\r\n",
      "\r\n",
      "import tensorflow as tf\r\n",
      "from tensorflow import keras\r\n",
      "from glob import glob \r\n",
      "\r\n",
      "from bert.model import BertModelLayer\r\n",
      "from bert.loader import StockBertConfig, map_stock_config_to_params, load_stock_weights\r\n",
      "from bert.tokenization.bert_tokenization import FullTokenizer\r\n",
      "\r\n",
      "from sklearn.metrics import confusion_matrix, classification_report\r\n",
      "\r\n",
      "#train = pd.read_csv('./data/amazon_reviews_us_Digital_Software_v1_00.tsv.gz', delimiter='\\t')[:100]\r\n",
      "#test = pd.read_csv('./data/amazon_reviews_us_Digital_Software_v1_00.tsv.gz', delimiter='\\t')[:100]\r\n",
      "\r\n",
      "#train.shape\r\n",
      "#train.head()\r\n",
      "\r\n",
      "import os\r\n",
      "\r\n",
      "os.system('rm uncased_L-12_H-768_A-12.zip')\r\n",
      "os.system('rm -rf uncased_L-12_H-768_A-12')\r\n",
      "\r\n",
      "os.system('wget -q https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip')\r\n",
      "\r\n",
      "#os.system('unzip uncased_L-12_H-768_A-12.zip')\r\n",
      "\r\n",
      "import zipfile\r\n",
      "with zipfile.ZipFile('uncased_L-12_H-768_A-12.zip', 'r') as zip_ref:\r\n",
      "  zip_ref.extractall('.')\r\n",
      "\r\n",
      "\r\n",
      "os.system('ls -al ./uncased_L-12_H-768_A-12')\r\n",
      "#subprocess.check_call([sys.executable, 'unzip', '-f', 'uncased_L-12_H-768_A-12.zip'])\r\n",
      "#subprocess.check_call([sys.executable, 'ls', '-al', './model/uncased_L-12_H-768_A-12'])\r\n",
      "\r\n",
      "bert_ckpt_dir = './uncased_L-12_H-768_A-12'\r\n",
      "bert_ckpt_file = os.path.join(bert_ckpt_dir, \"bert_model.ckpt\")\r\n",
      "bert_config_file = os.path.join(bert_ckpt_dir, \"bert_config.json\")\r\n",
      "\r\n",
      "CLASSES=[1, 2, 3, 4, 5]\r\n",
      "MAX_SEQ_LEN=128\r\n",
      "BATCH_SIZE=128\r\n",
      "EPOCHS=2\r\n",
      "STEPS_PER_EPOCH=1000\r\n",
      "\r\n",
      "def select_data_and_label_from_record(record):\r\n",
      "    x = {\r\n",
      "        'input_word_ids': record['input_ids'],\r\n",
      "        'input_mask': record['input_mask'],\r\n",
      "        'input_type_ids': record['segment_ids']\r\n",
      "    }\r\n",
      "    y = record['label_ids']\r\n",
      "\r\n",
      "    return (x, y)\r\n",
      "\r\n",
      "\r\n",
      "def file_based_input_dataset_builder(input_file, \r\n",
      "                                     seq_length, \r\n",
      "                                     is_training,\r\n",
      "                                     drop_remainder):\r\n",
      "\r\n",
      "  name_to_features = {\r\n",
      "      \"input_ids\": tf.io.FixedLenFeature([seq_length], tf.int64),\r\n",
      "      \"input_mask\": tf.io.FixedLenFeature([seq_length], tf.int64),\r\n",
      "      \"segment_ids\": tf.io.FixedLenFeature([seq_length], tf.int64),\r\n",
      "      \"label_ids\": tf.io.FixedLenFeature([], tf.int64),\r\n",
      "      \"is_real_example\": tf.io.FixedLenFeature([], tf.int64),\r\n",
      "  }\r\n",
      "\r\n",
      "  def _decode_record(record, name_to_features):\r\n",
      "    \"\"\"Decodes a record to a TensorFlow example.\"\"\"\r\n",
      "    example = tf.io.parse_single_example(record, name_to_features)\r\n",
      "\r\n",
      "    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.\r\n",
      "    # So cast all int64 to int32.\r\n",
      "    for name in list(example.keys()):\r\n",
      "      t = example[name]\r\n",
      "      if t.dtype == tf.int64:\r\n",
      "        t = tf.cast(t, tf.int32)\r\n",
      "      example[name] = t\r\n",
      "\r\n",
      "    return example\r\n",
      "\r\n",
      "#  def input_fn(params):\r\n",
      "#  \"\"\"The actual input function.\"\"\"\r\n",
      "#  batch_size = params[\"batch_size\"]\r\n",
      "\r\n",
      "  # For training, we want a lot of parallel reading and shuffling.\r\n",
      "  # For eval, we want no shuffling and parallel reading doesn't matter.\r\n",
      "  dataset = tf.data.TFRecordDataset(input_file)\r\n",
      "  if is_training:\r\n",
      "    dataset = dataset.repeat()\r\n",
      "    dataset = dataset.shuffle(buffer_size=100)\r\n",
      "\r\n",
      "  dataset = dataset.apply(\r\n",
      "      tf.data.experimental.map_and_batch(\r\n",
      "          lambda record: _decode_record(record, name_to_features),\r\n",
      "          batch_size=BATCH_SIZE,\r\n",
      "          drop_remainder=drop_remainder))\r\n",
      "\r\n",
      "  return dataset\r\n",
      "\r\n",
      "#  return input_fn\r\n",
      "\r\n",
      "\r\n",
      "# class ClassificationData:\r\n",
      "#   TEXT_COLUMN = 'review_body'\r\n",
      "#   LABEL_COLUMN = 'star_rating'\r\n",
      "\r\n",
      "#   def __init__(self, train, test, tokenizer: FullTokenizer, classes, max_seq_len=192):\r\n",
      "#     self.tokenizer = tokenizer\r\n",
      "#     self.max_seq_len = 0\r\n",
      "#     self.classes = classes\r\n",
      "    \r\n",
      "#     ((self.train_x, self.train_y), (self.test_x, self.test_y)) = map(self._prepare, [train, test])\r\n",
      "\r\n",
      "# #    print('max seq_len', self.max_seq_len)\r\n",
      "#     self.max_seq_len = min(self.max_seq_len, max_seq_len)\r\n",
      "#     self.train_x, self.test_x = map(self._pad, [self.train_x, self.test_x])\r\n",
      "\r\n",
      "#   def _prepare(self, df):\r\n",
      "#     x, y = [], []\r\n",
      "    \r\n",
      "#     for _, row in tqdm(df.iterrows()):\r\n",
      "#       text, label = row[ClassificationData.TEXT_COLUMN], row[ClassificationData.LABEL_COLUMN]\r\n",
      "#       tokens = self.tokenizer.tokenize(text)\r\n",
      "#       tokens = [\"[CLS]\"] + tokens + [\"[SEP]\"]\r\n",
      "#       token_ids = self.tokenizer.convert_tokens_to_ids(tokens)\r\n",
      "#       self.max_seq_len = max(self.max_seq_len, len(token_ids))\r\n",
      "#       x.append(token_ids)\r\n",
      "#       y.append(self.classes.index(label))\r\n",
      "\r\n",
      "#     return np.array(x), np.array(y)\r\n",
      "\r\n",
      "#   def _pad(self, ids):\r\n",
      "#     x = []\r\n",
      "#     for input_ids in ids:\r\n",
      "#       input_ids = input_ids[:min(len(input_ids), self.max_seq_len - 2)]\r\n",
      "#       input_ids = input_ids + [0] * (self.max_seq_len - len(input_ids))\r\n",
      "#       x.append(np.array(input_ids))\r\n",
      "#     return np.array(x)\r\n",
      "\r\n",
      "tokenizer = FullTokenizer(vocab_file=os.path.join(bert_ckpt_dir, \"vocab.txt\"))\r\n",
      "\r\n",
      "tokenizer.tokenize(\"I can't wait to visit Bulgaria again!\")\r\n",
      "\r\n",
      "tokens = tokenizer.tokenize(\"I can't wait to visit Bulgaria again!\")\r\n",
      "tokenizer.convert_tokens_to_ids(tokens)\r\n",
      "\r\n",
      "\r\n",
      "def flatten_layers(root_layer):\r\n",
      "    if isinstance(root_layer, keras.layers.Layer):\r\n",
      "        yield root_layer\r\n",
      "    for layer in root_layer._layers:\r\n",
      "        for sub_layer in flatten_layers(layer):\r\n",
      "            yield sub_layer\r\n",
      "\r\n",
      "\r\n",
      "def freeze_bert_layers(l_bert):\r\n",
      "    \"\"\"\r\n",
      "    Freezes all but LayerNorm and adapter layers - see arXiv:1902.00751.\r\n",
      "    \"\"\"\r\n",
      "    for layer in flatten_layers(l_bert):\r\n",
      "        if layer.name in [\"LayerNorm\", \"adapter-down\", \"adapter-up\"]:\r\n",
      "            layer.trainable = True\r\n",
      "        elif len(layer._layers) == 0:\r\n",
      "            layer.trainable = False\r\n",
      "        l_bert.embeddings_layer.trainable = False\r\n",
      "\r\n",
      "\r\n",
      "def create_learning_rate_scheduler(max_learn_rate=5e-5,\r\n",
      "                                   end_learn_rate=1e-7,\r\n",
      "                                   warmup_epoch_count=10,\r\n",
      "                                   total_epoch_count=90):\r\n",
      "\r\n",
      "    def lr_scheduler(epoch):\r\n",
      "        if epoch < warmup_epoch_count:\r\n",
      "            res = (max_learn_rate/warmup_epoch_count) * (epoch + 1)\r\n",
      "        else:\r\n",
      "            res = max_learn_rate*math.exp(math.log(end_learn_rate/max_learn_rate)*(epoch-warmup_epoch_count+1)/(total_epoch_count-warmup_epoch_count+1))\r\n",
      "        return float(res)\r\n",
      "    learning_rate_scheduler = tf.keras.callbacks.LearningRateScheduler(lr_scheduler, verbose=1)\r\n",
      "\r\n",
      "    return learning_rate_scheduler\r\n",
      "\r\n",
      "\r\n",
      "def create_model(max_seq_len, bert_ckpt_file, adapter_size):\r\n",
      "\r\n",
      "  with tf.io.gfile.GFile(bert_config_file, \"r\") as reader:\r\n",
      "    bc = StockBertConfig.from_json_string(reader.read())\r\n",
      "    bert_params = map_stock_config_to_params(bc)\r\n",
      "    bert_params.adapter_size = adapter_size \r\n",
      "    bert = BertModelLayer.from_params(bert_params, name=\"bert\")\r\n",
      "        \r\n",
      "  input_ids = keras.layers.Input(shape=(max_seq_len, ), dtype='int32', name=\"input_ids\")\r\n",
      "  bert_output = bert(input_ids)\r\n",
      "\r\n",
      "  print(\"bert shape\", bert_output.shape)\r\n",
      "\r\n",
      "  cls_out = keras.layers.Lambda(lambda seq: seq[:, 0, :])(bert_output)\r\n",
      "  cls_out = keras.layers.Dropout(0.5)(cls_out)\r\n",
      "  logits = keras.layers.Dense(units=768, activation=\"tanh\")(cls_out)\r\n",
      "  logits = keras.layers.Dropout(0.5)(logits)\r\n",
      "  logits = keras.layers.Dense(units=len(CLASSES), activation=\"softmax\")(logits)\r\n",
      "\r\n",
      "  model = keras.Model(inputs=input_ids, outputs=logits)\r\n",
      "  model.build(input_shape=(None, max_seq_len))\r\n",
      "\r\n",
      "  load_stock_weights(bert, bert_ckpt_file)\r\n",
      "\r\n",
      "  if adapter_size is not None:\r\n",
      "    freeze_bert_layers(bert)\r\n",
      "\r\n",
      "  return model\r\n",
      "\r\n",
      "\r\n",
      "if __name__ == '__main__':\r\n",
      "    parser = argparse.ArgumentParser()\r\n",
      "\r\n",
      "#    parser.add_argument('--model-type', type=str, default='bert')\r\n",
      "#    parser.add_argument('--model-name', type=str, default='bert-base-cased')\r\n",
      "    parser.add_argument('--train-data', type=str, default=os.environ['SM_CHANNEL_TRAIN'])\r\n",
      "    parser.add_argument('--validation-data', type=str, default=os.environ['SM_CHANNEL_VALIDATION'])\r\n",
      "    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])\r\n",
      "    parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS']))\r\n",
      "    parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])\r\n",
      "    parser.add_argument('--num-gpus', type=int, default=os.environ['SM_NUM_GPUS'])\r\n",
      "\r\n",
      "    args, _ = parser.parse_known_args()   \r\n",
      "#    model_type = args.model_type\r\n",
      "#    model_name = args.model_name\r\n",
      "    train_data = args.train_data\r\n",
      "    validation_data = args.validation_data\r\n",
      "    model_dir = args.model_dir\r\n",
      "    hosts = args.hosts\r\n",
      "    current_host = args.current_host\r\n",
      "    num_gpus = args.num_gpus\r\n",
      "\r\n",
      "    # features = ClassificationData(train, test, tokenizer, classes, max_seq_len=128)\r\n",
      "    # features.train_x.shape\r\n",
      "    # features.train_x[0]\r\n",
      "    # features.train_y[0]\r\n",
      "    # features.max_seq_len\r\n",
      "\r\n",
      "    adapter_size = None # Change to 64?\r\n",
      "    model = create_model(MAX_SEQ_LEN, bert_ckpt_file, adapter_size)\r\n",
      "\r\n",
      "    model.summary()\r\n",
      "\r\n",
      "    model.compile(\r\n",
      "      optimizer=keras.optimizers.Adam(1e-5),\r\n",
      "      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\r\n",
      "      metrics=[keras.metrics.SparseCategoricalAccuracy(name=\"acc\")]\r\n",
      "    )\r\n",
      "\r\n",
      "    log_dir = \"log/classification/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%s\")\r\n",
      "    tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir)\r\n",
      "\r\n",
      "    train_data_filenames = glob('{}/*.tfrecord'.format(train_data))\r\n",
      "    print(train_data_filenames)\r\n",
      "\r\n",
      "    # Create an input function for training. drop_remainder = True for using TPUs.\r\n",
      "    train_dataset = file_based_input_dataset_builder(\r\n",
      "        train_data_filenames,\r\n",
      "        seq_length=MAX_SEQ_LEN,\r\n",
      "        is_training=True,\r\n",
      "        drop_remainder=False)\r\n",
      "\r\n",
      "    print('*********** {}'.format(train_dataset))\r\n",
      "\r\n",
      "#    (train_dataset_X, train_dataset_y) = train_dataset.map(select_data_and_label_from_record)\r\n",
      "    train_dataset_2 = train_dataset.map(select_data_and_label_from_record)\r\n",
      "    print(train_dataset_2)\r\n",
      "\r\n",
      "    iterator = iter(train_dataset_2)\r\n",
      "    next_element = iterator.get_next()\r\n",
      "    print('*********** {}'.format(next_element))\r\n",
      "\r\n",
      "#    if is_training:\r\n",
      "#        dataset = dataset.shuffle(100)\r\n",
      "#        dataset = dataset.repeat()\r\n",
      "\r\n",
      "#    dataset = dataset.batch(batch_size, drop_remainder=is_training)\r\n",
      "#    dataset = dataset.prefetch(1024)\r\n",
      "\r\n",
      "#    iterator = iter(train_dataset_2)\r\n",
      "#    next_element = iterator.get_next()\r\n",
      "#    print('*********** {}'.format(next_element))\r\n",
      "\r\n",
      "    history = model.fit(\r\n",
      "      train_dataset_2,\r\n",
      "#       x=train_dataset_X,\r\n",
      "#       y=train_dataset_y,\r\n",
      "#      train_dataset.batch(10),\r\n",
      "#      train_dataset,\r\n",
      "#      x=features.train_x, \r\n",
      "#      y=features.train_y,\r\n",
      "#      validation_split=0.1,\r\n",
      "      batch_size=BATCH_SIZE,\r\n",
      "      shuffle=True,\r\n",
      "      epochs=EPOCHS,\r\n",
      "      steps_per_epoch=STEPS_PER_EPOCH,\r\n",
      "      callbacks=[tensorboard_callback]\r\n",
      "    )\r\n",
      "\r\n",
      "#    _, train_acc = model.evaluate(features.train_x, features.train_y)\r\n",
      "#    _, test_acc = model.evaluate(features.test_x, features.test_y)\r\n",
      "\r\n",
      "#    print(\"train acc\", train_acc)\r\n",
      "#    print(\"test acc\", test_acc)\r\n",
      "\r\n",
      "#    y_pred = model.predict(features.test_x).argmax(axis=-1)\r\n",
      "\r\n",
      "#    print(classification_report(features.test_y, y_pred)) #, target_names=classes))\r\n",
      "\r\n",
      "#    cm = confusion_matrix(features.test_y, y_pred)\r\n",
      "#    df_cm = pd.DataFrame(cm, index=classes, columns=classes)\r\n",
      "\r\n",
      "    sentences = [\r\n",
      "      \"This is just OK.\",\r\n",
      "      \"This sucks.\",\r\n",
      "      \"This is great.\"\r\n",
      "    ]\r\n",
      "\r\n",
      "    pred_tokens = map(tokenizer.tokenize, sentences)\r\n",
      "    pred_tokens = map(lambda tok: [\"[CLS]\"] + tok + [\"[SEP]\"], pred_tokens)\r\n",
      "    pred_token_ids = list(map(tokenizer.convert_tokens_to_ids, pred_tokens))\r\n",
      "\r\n",
      "    pred_token_ids = map(lambda tids: tids +[0]*(MAX_SEQ_LEN-len(tids)),pred_token_ids)\r\n",
      "    pred_token_ids = np.array(list(pred_token_ids))\r\n",
      "\r\n",
      "    predictions = model.predict(pred_token_ids).argmax(axis=-1)\r\n",
      "\r\n",
      "    for review_body, star_rating in zip(sentences, predictions):\r\n",
      "       print(\"review_body:\", review_body, \"\\star_rating:\", CLASSES[star_rating])\r\n",
      "       print()\r\n",
      "\r\n",
      "#    model.save('/opt/ml/model/0/', save_format='tf')\r\n",
      "#    model.save('/opt/ml/model/bert_reviews.h5')\r\n"
     ]
    }
   ],
   "source": [
    "!cat src_bert_tf2/tf_bert_reviews.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.tensorflow import TensorFlow\n",
    "\n",
    "# Training output is written under this S3 location.\n",
    "model_output_path = 's3://{}/models/tf2-bert'.format(bucket)\n",
    "\n",
    "# enable_cloudwatch_metrics is no longer passed: the SageMaker Python SDK documents it as\n",
    "# deprecated and ignored, so it only produced a deprecation warning.\n",
    "bert_estimator = TensorFlow(entry_point='tf_bert_reviews.py',\n",
    "                         source_dir='src_bert_tf2',\n",
    "                         role=role,\n",
    "                         train_instance_count=1, # 1 is actually faster due to communication overhead with >1\n",
    "                         train_instance_type='ml.p3.16xlarge',\n",
    "                         py_version='py3',\n",
    "                         framework_version='2.0.0',\n",
    "#                         hyperparameters={'model_type':'bert',\n",
    "#                                          'model_name': 'bert-base-cased'},\n",
    "                         output_path=model_output_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the train and validation channels are supplied to the job.\n",
    "training_channels = {\n",
    "    'train': s3_input_train_data,\n",
    "    'validation': s3_input_validation_data,\n",
    "}\n",
    "\n",
    "# wait=False: do not block the notebook until the training job completes.\n",
    "bert_estimator.fit(inputs=training_channels, wait=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training_job_name:  tensorflow-training-2020-03-31-05-58-21-428\n"
     ]
    }
   ],
   "source": [
    "training_job_name = bert_estimator.latest_training_job.name\n",
    "print('training_job_name:  {}'.format(training_job_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# from sagemaker.tensorflow import TensorFlow\n",
    "\n",
    "# bert_estimator = TensorFlow.attach(training_job_name=training_job_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<b>Review <a href=\"https://console.aws.amazon.com/sagemaker/home?region=us-east-1#/jobs/tensorflow-training-2020-03-31-05-58-21-428\">Training Job</a> After About 5 Minutes</b>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# display/HTML come from the public IPython.display module; importing them from\n",
    "# IPython.core.display is the deprecated location.\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "# Link to the training job page in the SageMaker console for this region.\n",
    "display(HTML('<b>Review <a href=\"https://console.aws.amazon.com/sagemaker/home?region={}#/jobs/{}\">Training Job</a> After About 5 Minutes</b>'.format(region, training_job_name)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<b>Review <a href=\"https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/aws/sagemaker/TrainingJobs;prefix=tensorflow-training-2020-03-31-05-58-21-428;streamFilter=typeLogStreamPrefix\">CloudWatch Logs</a> After About 5 Minutes</b>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# display/HTML come from the public IPython.display module; importing them from\n",
    "# IPython.core.display is the deprecated location.\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "# Link to the CloudWatch log streams whose names start with the training job name.\n",
    "display(HTML('<b>Review <a href=\"https://console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/TrainingJobs;prefix={};streamFilter=typeLogStreamPrefix\">CloudWatch Logs</a> After About 5 Minutes</b>'.format(region, training_job_name)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<b>Review <a href=\"https://s3.console.aws.amazon.com/s3/buckets/sagemaker-us-east-1-835319576252/models/tf2-bert/tensorflow-training-2020-03-31-05-58-21-428/?region=us-east-1&tab=overview\">S3 Output Data</a> After The Training Job Has Completed</b>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# display/HTML come from the public IPython.display module; importing them from\n",
    "# IPython.core.display is the deprecated location.\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "# Key prefix of this job's output, matching the 'models/tf2-bert' estimator output path.\n",
    "training_job_s3_output_prefix = 'models/tf2-bert/{}'.format(training_job_name) # 'models/tf-bert/script-mode/training-runs/{}'.format(training_job_name)\n",
    "\n",
    "display(HTML('<b>Review <a href=\"https://s3.console.aws.amazon.com/s3/buckets/{}/{}/?region={}&tab=overview\">S3 Output Data</a> After The Training Job Has Completed</b>'.format(bucket, training_job_s3_output_prefix, region)))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download and Load the Trained Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# download the model artifact from AWS S3\n",
    "\n",
    "!aws s3 cp $model_output_path/$training_job_name/output/model.tar.gz ./models/tf2-bert/\n",
    "\n",
    "#!aws s3 cp s3://sagemaker-us-east-1-835319576252/models/tf-bert/script-mode/training-runs/tensorflow-training-2020-03-24-04-41-39-405/output/model.tar.gz ./models/tf2-bert/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tarfile\n",
    "import pickle as pkl\n",
    "\n",
    "# Unpack the downloaded model archive next to it. The context manager closes the\n",
    "# archive handle even if extraction raises.\n",
    "with tarfile.open('./models/tf2-bert/model.tar.gz') as tar:\n",
    "    tar.extractall(path='./models/tf2-bert')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "!ls -al ./models/tf2-bert"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls -al ./models/tf2-bert/0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Must upgrade wrapt before installing TF\n",
    "!pip install -q wrapt --upgrade --ignore-installed\n",
    "!pip install -q tensorflow==1.15.2\n",
    "!pip install -q tensorflow-hub==0.7.0\n",
    "!pip install -q bert-tensorflow==1.0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# No earlier cell imports TensorFlow, so import it here (after the pip install cell above)\n",
    "# to keep `tf` defined for this cell and the ones that follow.\n",
    "import tensorflow as tf\n",
    "\n",
    "# Load the SavedModel extracted under ./models/tf2-bert/0.\n",
    "saved_model = tf.saved_model.load_v2(\n",
    "    './models/tf2-bert/0',\n",
    "    tags=None\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO:  Load the Keras model (or weights?) using the most appropriate mechanism (not necessarily this way) \n",
    "#loaded_model = create_model(data.max_seq_len, adapter_size=None)\n",
    "#loaded_model.load_weights(\"movie_reviews.h5\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `loaded_model` and `data` are not defined by any cell in this notebook (the loading\n",
    "# cell above is still a commented-out TODO). Skip the evaluation with a message instead\n",
    "# of raising NameError so the notebook can still run top to bottom.\n",
    "if 'loaded_model' in globals() and 'data' in globals():\n",
    "    _, train_acc = loaded_model.evaluate(data.train_x, data.train_y)\n",
    "    _, test_acc = loaded_model.evaluate(data.test_x, data.test_y)\n",
    "\n",
    "    print(\"train acc\", train_acc)\n",
    "    print(\" test acc\", test_acc)\n",
    "else:\n",
    "    print('Skipping evaluation: loaded_model / data are not defined (see the TODO above).')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "inference = saved_model.signatures[\"serving_default\"]\n",
    "print(inference.inputs)\n",
    "print(inference.structured_outputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predict \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bert import tokenization\n",
    "import tensorflow_hub as hub\n",
    "\n",
    "BERT_MODEL_HUB = \"https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1\"\n",
    "\n",
    "def create_tokenizer_from_hub_module():\n",
    "    \"\"\"Build a FullTokenizer from the vocab file and lower-casing flag reported by the Hub module.\"\"\"\n",
    "    hub_graph = tf.Graph()\n",
    "    with hub_graph.as_default():\n",
    "        module = hub.Module(BERT_MODEL_HUB)\n",
    "        info = module(signature=\"tokenization_info\", as_dict=True)\n",
    "        # Evaluate both tokenization settings in a single session run.\n",
    "        fetches = [info[\"vocab_file\"], info[\"do_lower_case\"]]\n",
    "        with tf.Session() as hub_session:\n",
    "            vocab_file, do_lower_case = hub_session.run(fetches)\n",
    "\n",
    "        return tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_predict_features(features, seq_length):\n",
    "  \"\"\"Stack per-example features into four int32 constant tensors.\n",
    "\n",
    "  This is for demo purposes and does NOT scale to large data sets: every\n",
    "  example is materialized in memory as part of a tf.constant.\n",
    "\n",
    "  Args:\n",
    "    features: sized sequence of objects exposing `input_ids`, `input_mask`,\n",
    "      `segment_ids` and `label_id` attributes.\n",
    "    seq_length: second dimension of the [num_examples, seq_length] tensors.\n",
    "\n",
    "  Returns:\n",
    "    Tuple `(input_ids, input_mask, segment_ids, label_ids)`. The first three\n",
    "    have shape [num_examples, seq_length]; `label_ids` has shape [num_examples].\n",
    "  \"\"\"\n",
    "  num_examples = len(features)\n",
    "\n",
    "  def _as_int32(values, shape):\n",
    "    # Shared constructor so all four outputs use the same dtype.\n",
    "    return tf.constant(values, shape=shape, dtype=tf.int32)\n",
    "\n",
    "  per_token_shape = [num_examples, seq_length]\n",
    "  input_ids = _as_int32([f.input_ids for f in features], per_token_shape)\n",
    "  input_mask = _as_int32([f.input_mask for f in features], per_token_shape)\n",
    "  segment_ids = _as_int32([f.segment_ids for f in features], per_token_shape)\n",
    "\n",
    "  # No trailing comma here: the previous version ended this statement with ',',\n",
    "  # which wrapped label_ids in a 1-tuple instead of returning the tensor itself.\n",
    "  label_ids = _as_int32([f.label_id for f in features], [num_examples])\n",
    "\n",
    "  return input_ids, input_mask, segment_ids, label_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src_bert_tf import amazon_run_classifier\n",
    "\n",
    "MAX_SEQ_LENGTH = 128\n",
    "LABEL_VALUES = [1, 2, 3, 4, 5]\n",
    "\n",
    "def predict(in_sentences):\n",
    "    \"\"\"Run the loaded serving signature on a batch of sentences.\n",
    "\n",
    "    Args:\n",
    "        in_sentences: iterable of sentences; each one becomes the `text_a` of an InputExample.\n",
    "\n",
    "    Returns:\n",
    "        List of `(sentence, probabilities, label_value)` tuples, one per input sentence.\n",
    "    \"\"\"\n",
    "    tokenizer = create_tokenizer_from_hub_module()\n",
    "    print('**** TOKENIZER {}****'.format(tokenizer))\n",
    "\n",
    "    # label=-1 is a placeholder value passed for every prediction example.\n",
    "    input_examples = [amazon_run_classifier.InputExample(guid=\"\", text_a=x, text_b=None, label=-1) for x in in_sentences]\n",
    "    input_features = amazon_run_classifier.convert_examples_to_features(input_examples, LABEL_VALUES, MAX_SEQ_LENGTH, tokenizer)\n",
    "\n",
    "    input_ids, input_mask, segment_ids, label_ids = get_predict_features(input_features, MAX_SEQ_LENGTH)\n",
    "    print(type(input_ids))\n",
    "\n",
    "    # Keep the signature's return value: the previous version discarded it and then\n",
    "    # read a `predictions` name that was never bound inside this function.\n",
    "    outputs = inference(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_ids=label_ids)\n",
    "\n",
    "    # NOTE(review): assumes the signature returns batched 'probabilities' and 'labels'\n",
    "    # entries (the keys the previous return expression indexed) that can be iterated row\n",
    "    # by row, which needs eager execution -- confirm against inference.structured_outputs.\n",
    "    return [(sentence, probabilities, LABEL_VALUES[int(label)])\n",
    "            for sentence, probabilities, label in zip(in_sentences, outputs['probabilities'], outputs['labels'])]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "pred_sentences = [\n",
    "  \"That movie was absolutely awful\",\n",
    "  \"The acting was a bit lacking\",\n",
    "  \"The film was creative and surprising\",\n",
    "  \"Absolutely fantastic!\"\n",
    "]\n",
    "\n",
    "# predict() iterates its argument and passes each item to InputExample as text_a, so it is\n",
    "# given the plain Python strings rather than a tf string tensor built from them.\n",
    "predictions = predict(pred_sentences)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
